#include "FORCE_STRESS.h"

#include "source_lcao/module_dftu/dftu.h" //Quxin add for DFT+U on 20201029
#include "source_pw/module_pwdft/global.h"
#include "source_io/output_log.h"
#include "source_io/module_parameter/parameter.h"
// new
#include "source_base/timer.h"
#include "source_cell/module_neighbor/sltk_grid_driver.h"
#include "source_estate/elecstate_lcao.h"
#include "source_estate/module_pot/H_TDDFT_pw.h"       // Taoni add 2025-02-20
#include "source_estate/module_pot/efield.h"           // liuyu add 2022-05-18
#include "source_estate/module_pot/gatefield.h"        // liuyu add 2022-09-13
#include "source_hamilt/module_surchem/surchem.h" //sunml add 2022-08-10
#include "source_hamilt/module_vdw/vdw.h"
#include "source_io/module_parameter/parameter.h"
#ifdef __MLALGO
#include "source_lcao/module_deepks/LCAO_deepks.h"    //caoyu add for deepks 2021-06-03
#include "source_lcao/module_deepks/LCAO_deepks_io.h" // mohan add 2024-07-22
#endif
#include "source_lcao/module_operator_lcao/dftu_lcao.h"
#include "source_lcao/module_operator_lcao/dspin_lcao.h"
#include "source_lcao/module_operator_lcao/nonlocal_new.h"


// mohan add 2025-11-04
template <>
void assign_dmk_ptr<double>(
    elecstate::DensityMatrix<double,double>* dm,
    std::vector<std::vector<double>>*& dmk_d,
    std::vector<std::vector<std::complex<double>>>*& dmk_c,
    bool gamma_only_local
) {
    auto& dmk_tmp = dm->get_DMK_vector();
    dmk_d = &dmk_tmp;
    dmk_c = nullptr;
}

template <>
void assign_dmk_ptr<std::complex<double>>(
    elecstate::DensityMatrix<std::complex<double>,double>* dm,
    std::vector<std::vector<double>>*& dmk_d,
    std::vector<std::vector<std::complex<double>>>*& dmk_c,
    bool gamma_only_local
) {
    auto& dmk_tmp = dm->get_DMK_vector();
    dmk_c = &dmk_tmp;
    dmk_d = nullptr;
}



template <typename T>
Force_Stress_LCAO<T>::Force_Stress_LCAO(Record_adj& ra, const int nat_in) : RA(&ra), nat(nat_in)
{
}
template <typename T>
Force_Stress_LCAO<T>::~Force_Stress_LCAO()
{
}
template <typename T>
void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
                                          const bool isforce,
                                          const bool isstress,
                                          const bool istestf,
                                          const bool istests,
                                          const Grid_Driver& gd,
                                          Parallel_Orbitals& pv,
                                          const elecstate::ElecState* pelec,
                                          LCAO_domain::Setup_DM<T> &dmat, // mohan add 2025-11-03
                                          const psi::Psi<T>* psi,
                                          const TwoCenterBundle& two_center_bundle,
                                          const LCAO_Orbitals& orb,
                                          ModuleBase::matrix& fcs,
                                          ModuleBase::matrix& scs,
                                          const pseudopot_cell_vl& locpp,
                                          const Structure_Factor& sf,
                                          const K_Vectors& kv,
                                          ModulePW::PW_Basis* rhopw,
										  surchem& solvent,
										  Plus_U &dftu, // mohan add 2025-11-07
                                          Setup_DeePKS<T>& deepks,
										  Exx_NAO<T> &exx_nao,
                                          ModuleSymmetry::Symmetry* symm)
{
    ModuleBase::TITLE("Force_Stress_LCAO", "getForceStress");
    ModuleBase::timer::tick("Force_Stress_LCAO", "getForceStress");

    if (!isforce && !isstress)
    {
        ModuleBase::timer::tick("Force_Stress_LCAO", "getForceStress");
        return;
    }

    const int nat = ucell.nat;

    ForceStressArrays fsr; // mohan add 2024-06-15

    // total force : ModuleBase::matrix fcs;

    // part of total force
    ModuleBase::matrix foverlap;
    ModuleBase::matrix ftvnl_dphi;
    ModuleBase::matrix fvnl_dbeta;
    ModuleBase::matrix fvl_dphi;
    ModuleBase::matrix fvl_dvl;
    ModuleBase::matrix fewalds;
    ModuleBase::matrix fcc;
    ModuleBase::matrix fscc;
    ModuleBase::matrix fvnl_dalpha; // deepks

    fvl_dphi.create(nat, 3); // must do it now, update it later, noted by zhengdy

    if (isforce)
    {
        fcs.create(nat, 3);
        foverlap.create(nat, 3); // overlap force
        ftvnl_dphi.create(nat, 3); // pulay force of NAO
        fvnl_dbeta.create(nat, 3); // pulay force of non-local projectors
        fvl_dvl.create(nat, 3); // force from local potentials
        fewalds.create(nat, 3); // Ewald force
        fcc.create(nat, 3); // force due to core correction
        fscc.create(nat, 3); // force due to self-consistent field
        fvnl_dalpha.create(nat, 3); // deepks

        // calculate basic terms in Force, same method with PW base
        this->calForcePwPart(ucell, fvl_dvl, fewalds, fcc, fscc, pelec->f_en.etxc,
              pelec->vnew, pelec->vnew_exist, pelec->charge, rhopw, locpp, sf);
    }

    // total stress : ModuleBase::matrix scs
    ModuleBase::matrix sigmacc;
    ModuleBase::matrix sigmadvl;
    ModuleBase::matrix sigmaewa;
    ModuleBase::matrix sigmaxc;
    ModuleBase::matrix sigmahar;
    ModuleBase::matrix soverlap;
    ModuleBase::matrix stvnl_dphi;
    ModuleBase::matrix svnl_dbeta;
    ModuleBase::matrix svl_dphi;
    ModuleBase::matrix svnl_dalpha; // deepks

    //! stress
    if (isstress)
    {
        scs.create(3, 3);
        sigmacc.create(3, 3);
        sigmadvl.create(3, 3);
        sigmaewa.create(3, 3);
        sigmaxc.create(3, 3);
        sigmahar.create(3, 3);

        soverlap.create(3, 3);
        stvnl_dphi.create(3, 3);
        svnl_dbeta.create(3, 3);
        svl_dphi.create(3, 3);
        svnl_dalpha.create(3, 3);

        // calculate basic terms in Stress, similar method with PW base
        this->calStressPwPart(ucell, sigmadvl, sigmahar, sigmaewa, sigmacc,
          sigmaxc, pelec->f_en.etxc, pelec->charge, rhopw, locpp, sf);
    }

    //! atomic forces from integration (4 terms)
    this->integral_part(PARAM.globalv.gamma_only_local, isforce, isstress,
                        ucell, gd, fsr, pelec, dmat.dm, psi, foverlap, ftvnl_dphi, // add dmat.dm, mohan 20251104
                        fvnl_dbeta, fvl_dphi, soverlap, stvnl_dphi, svnl_dbeta,
                        svl_dphi, fvnl_dalpha, svnl_dalpha, deepks, 
                        two_center_bundle, orb, pv, kv);

    // calculate force and stress for Nonlocal part
    if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2)
    {
        hamilt::NonlocalNew<hamilt::OperatorLCAO<T, double>> tmp_nonlocal(nullptr,
          kv.kvec_d, nullptr, &ucell, orb.cutoffs(), &gd, two_center_bundle.overlap_orb_beta.get());

        if (PARAM.inp.nspin == 2)
        {
            dmat.dm->switch_dmr(1);
        }

        const hamilt::HContainer<double>* dmr = dmat.dm->get_DMR_pointer(1);
        tmp_nonlocal.cal_force_stress(isforce, isstress, dmr, fvnl_dbeta, svnl_dbeta);
        if (PARAM.inp.nspin == 2)
        {
            dmat.dm->switch_dmr(0);
        }
    }
    else if (PARAM.inp.nspin == 4)
    {
        hamilt::NonlocalNew<hamilt::OperatorLCAO<std::complex<double>, std::complex<double>>> tmp_nonlocal(
            nullptr, kv.kvec_d, nullptr, &ucell, orb.cutoffs(), &gd,
            two_center_bundle.overlap_orb_beta.get());

        // calculate temporary complex DMR for nonlocal force&stress
        // In fact, only SOC part need the imaginary part of DMR for correct force&stress
        hamilt::HContainer<std::complex<double>> tmp_dmr(dmat.dm->get_DMR_pointer(1)->get_paraV());
        std::vector<int> ijrs = dmat.dm->get_DMR_pointer(1)->get_ijr_info();
        tmp_dmr.insert_ijrs(&ijrs);
        tmp_dmr.allocate();
        dmat.dm->cal_DMR_full(&tmp_dmr);
        tmp_nonlocal.cal_force_stress(isforce, isstress, &tmp_dmr, fvnl_dbeta, svnl_dbeta);
    }

    //! forces and stress from vdw
    //  Peize Lin add 2014-04-04, update 2021-03-09
    //  jiyy add 2019-05-18, update 2021-05-02
    ModuleBase::matrix force_vdw;
    ModuleBase::matrix stress_vdw;
    auto vdw_solver = vdw::make_vdw(ucell, PARAM.inp);
    if (vdw_solver != nullptr)
    {
        if (isforce)
        {
            force_vdw.create(nat, 3);
            const std::vector<ModuleBase::Vector3<double>>& force_vdw_temp = vdw_solver->get_force();
            for (int iat = 0; iat < ucell.nat; ++iat)
            {
                force_vdw(iat, 0) = force_vdw_temp[iat].x;
                force_vdw(iat, 1) = force_vdw_temp[iat].y;
                force_vdw(iat, 2) = force_vdw_temp[iat].z;
            }
        }
        if (isstress)
        {
            stress_vdw = vdw_solver->get_stress().to_matrix();
        }
    }

    //! forces from E-field
    ModuleBase::matrix fefield;
    if (PARAM.inp.efield_flag && isforce)
    {
        fefield.create(nat, 3);
        elecstate::Efield::compute_force(ucell, fefield);
    }

    //! atomic forces from E-field of rt-TDDFT
    ModuleBase::matrix fefield_tddft;
    if (PARAM.inp.esolver_type == "tddft" && isforce)
    {
        fefield_tddft.create(nat, 3);
        elecstate::H_TDDFT_pw::compute_force(ucell, fefield_tddft);
    }

    //! atomic forces from gate field
    ModuleBase::matrix fgate;
    if (PARAM.inp.gate_flag && isforce)
    {
        fgate.create(nat, 3);
        elecstate::Gatefield::compute_force(ucell, fgate);
    }

    //! atomic forces from implicit solvation model
    ModuleBase::matrix fsol;
    if (PARAM.inp.imp_sol && isforce)
    {
        fsol.create(nat, 3);
        solvent.cal_force_sol(ucell, rhopw, locpp.vloc, fsol);
    }

    //! atomic forces from DFT+U (Quxin version)
    ModuleBase::matrix force_u;
    ModuleBase::matrix stress_u;

    if (PARAM.inp.dft_plus_u) // Quxin add for DFT+U on 20201029
    {
        if (isforce)
        {
            force_u.create(nat, 3);
        }
        if (isstress)
        {
            stress_u.create(3, 3);
        }
        if (PARAM.inp.dft_plus_u == 2)
        {
			// dftu.force_stress(ucell, gd, pelec, pv, fsr, force_u, stress_u, kv);
			// mohan modify 2025-11-03
            std::vector<std::vector<double>>* dmk_d = nullptr;
            std::vector<std::vector<std::complex<double>>>* dmk_c = nullptr;
            // add a new template function
            assign_dmk_ptr<T>(dmat.dm, dmk_d, dmk_c, PARAM.globalv.gamma_only_local);
		    dftu.force_stress(ucell, gd, dmk_d, dmk_c, pv, fsr, force_u, stress_u, kv);
        }
        else
        {
            hamilt::DFTU<hamilt::OperatorLCAO<T, double>> tmpu(nullptr, // HK and SK are not used for force&stress
                                                                   kv.kvec_d,
                                                                   nullptr, // HR are not used for force&stress
                                                                   ucell,
                                                                   &gd,
                                                                   two_center_bundle.overlap_orb_onsite.get(),
                                                                   orb.cutoffs(),
                                                                   &dftu);

            tmpu.cal_force_stress(isforce, isstress, force_u, stress_u);
        }
    }

    // atomic force and stress for DeltaSpin
    ModuleBase::matrix force_dspin;
    ModuleBase::matrix stress_dspin;
    if (PARAM.inp.sc_mag_switch)
    {
        if (isforce)
        {
            force_dspin.create(nat, 3);
        }
        if (isstress)
        {
            stress_dspin.create(3, 3);
        }

        hamilt::DeltaSpin<hamilt::OperatorLCAO<T, double>> tmp_dspin(nullptr,
                                                                     kv.kvec_d,
                                                                     nullptr,
                                                                     ucell,
                                                                     &gd,
                                                                     two_center_bundle.overlap_orb_onsite.get(),
                                                                     orb.cutoffs());

        if (PARAM.inp.nspin == 2)
        {
            dmat.dm->switch_dmr(2);
        }
        const hamilt::HContainer<double>* dmr = dmat.dm->get_DMR_pointer(1);
        tmp_dspin.cal_force_stress(isforce, isstress, dmr, force_dspin, stress_dspin);
        if (PARAM.inp.nspin == 2)
        {
            dmat.dm->switch_dmr(0);
        }
    }

    if (!PARAM.globalv.gamma_only_local)
    {
        this->flk.finish_ftable(fsr);
    }

#ifdef __EXX
    // Force and Stress contribution from exx
    ModuleBase::matrix force_exx;
    ModuleBase::matrix stress_exx;
    if (GlobalC::exx_info.info_global.cal_exx)
    {
        if (isforce)
        {
            if (GlobalC::exx_info.info_ri.real_number)
            {
                exx_nao.exd->cal_exx_force(ucell.nat);
                force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exd->get_force();
            }
            else
            {
                exx_nao.exc->cal_exx_force(ucell.nat);
                force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exc->get_force();
            }
        }
        if (isstress)
        {
            if (GlobalC::exx_info.info_ri.real_number)
            {
                exx_nao.exd->cal_exx_stress(ucell.omega, ucell.lat0);
                stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exd->get_stress();
            }
            else
            {
                exx_nao.exc->cal_exx_stress(ucell.omega, ucell.lat0);
                stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_nao.exc->get_stress();
            }
        }
    }
#endif
    //--------------------------------
    // begin calculate and output force
    //--------------------------------
    if (isforce)
    {
        //---------------------------------
        // sum all parts of force!
        //---------------------------------
        for (int i = 0; i < 3; i++)
        {
            double sum = 0.0;

            for (int iat = 0; iat < nat; iat++)
            {
                fcs(iat, i) += foverlap(iat, i) + ftvnl_dphi(iat, i) + fvnl_dbeta(iat, i) + fvl_dphi(iat, i)
                               + fvl_dvl(iat, i) // derivative of local potential force (pw)
                               + fewalds(iat, i) // ewald force (pw)
                               + fcc(iat, i)     // nonlinear core correction force (pw)
                               + fscc(iat, i);   // self consistent corretion force (pw)

                // Force contribution from DFT+U, Quxin add on 20201029
                if (PARAM.inp.dft_plus_u)
                {
                    fcs(iat, i) += force_u(iat, i);
                }
                if (PARAM.inp.sc_mag_switch)
                {
                    fcs(iat, i) += force_dspin(iat, i);
                }
#ifdef __EXX
                // Force contribution from exx
                if (GlobalC::exx_info.info_global.cal_exx)
                {
                    fcs(iat, i) += force_exx(iat, i);
                }
#endif
                // VDW force of vdwd2 or vdwd3
                if (vdw_solver != nullptr)
                {
                    fcs(iat, i) += force_vdw(iat, i);
                }
                // E-field force
                if (PARAM.inp.efield_flag)
                {
                    fcs(iat, i) += fefield(iat, i);
                }
                // E-field force of tddft
                if (PARAM.inp.esolver_type == "tddft")
                {
                    fcs(iat, i) += fefield_tddft(iat, i);
                }
                // Gate field force
                if (PARAM.inp.gate_flag)
                {
                    fcs(iat, i) += fgate(iat, i);
                }
                // implicit solvation model
                if (PARAM.inp.imp_sol)
                {
                    fcs(iat, i) += fsol(iat, i);
                }
#ifdef __MLALGO
                // mohan add 2021-08-04
                if (PARAM.inp.deepks_scf)
                {
                    fcs(iat, i) += fvnl_dalpha(iat, i);
                }
#endif
                // sum total force for correction
                sum += fcs(iat, i);
            }

            if (!(PARAM.inp.gate_flag || PARAM.inp.efield_flag))
            {
                for (int iat = 0; iat < nat; ++iat)
                {
                    fcs(iat, i) -= sum / nat;
                }
            }
        }

        if (PARAM.inp.gate_flag || PARAM.inp.efield_flag)
        {
            GlobalV::ofs_running << "Atomic forces are not shifted if gate_flag or efield_flag == true!" << std::endl;
        }

        // pengfei 2016-12-20
        if (ModuleSymmetry::Symmetry::symm_flag == 1)
        {
            this->forceSymmetry(ucell, fcs, symm);
        }

        // compute forces using the DeePKS model
        deepks.write_forces(fcs, fvnl_dalpha, PARAM.inp);

        // print Rydberg force or not
        bool ry = false;
        if (istestf)
        {
            // test
            // ModuleBase::matrix fvlocal;
            // fvlocal.create(nat,3);
            ModuleBase::matrix ftvnl;
            ftvnl.create(nat, 3);
            for (int iat = 0; iat < nat; iat++)
            {
                for (int i = 0; i < 3; i++)
                {
                    // fvlocal(iat,i) = fvl_dphi(iat,i) + fvl_dvl(iat,i);
                    ftvnl(iat, i) = ftvnl_dphi(iat, i) + fvnl_dbeta(iat, i);
                }
            }

            GlobalV::ofs_running << "\n PARTS OF FORCE: " << std::endl;
            GlobalV::ofs_running << std::setiosflags(std::ios::showpos);
            GlobalV::ofs_running << std::setiosflags(std::ios::fixed) << std::setprecision(8) << std::endl;
            //-----------------------------
            // regular force terms test.
            //-----------------------------
            // this->print_force("OVERLAP    FORCE",foverlap,1,ry);
            ModuleIO::print_force(GlobalV::ofs_running, ucell, "OVERLAP    FORCE", foverlap, false);
            // this->print_force("TVNL_DPHI  force",ftvnl_dphi,PARAM.inp.test_force);
            // this->print_force("VNL_DBETA  force",fvnl_dbeta,PARAM.inp.test_force);
            // this->print_force("T_VNL      FORCE",ftvnl,1,ry);
            ModuleIO::print_force(GlobalV::ofs_running, ucell, "T_VNL      FORCE", ftvnl, false);
            ModuleIO::print_force(GlobalV::ofs_running, ucell, "VL_dPHI    FORCE", fvl_dphi, false);
            // this->print_force("VL_dPHI    FORCE",fvl_dphi,1,ry);
            // this->print_force("VL_dVL     FORCE",fvl_dvl,1,ry);
            ModuleIO::print_force(GlobalV::ofs_running, ucell, "VL_dVL     FORCE", fvl_dvl, false);
            ModuleIO::print_force(GlobalV::ofs_running, ucell, "EWALD      FORCE", fewalds, false);
            // this->print_force("VLOCAL     FORCE",fvlocal,PARAM.inp.test_force);
            // this->print_force("EWALD      FORCE",fewalds,1,ry);
            ModuleIO::print_force(GlobalV::ofs_running, ucell, "NLCC       FORCE", fcc, false);
            ModuleIO::print_force(GlobalV::ofs_running, ucell, "SCC        FORCE", fscc, false);
            // this->print_force("NLCC       FORCE",fcc,1,ry);
            // this->print_force("SCC        FORCE",fscc,1,ry);
            //-------------------------------
            // put extra force here for test!
            //-------------------------------
            if (PARAM.inp.efield_flag)
            {
                ModuleIO::print_force(GlobalV::ofs_running, ucell, "EFIELD     FORCE", fefield, false);
                // this->print_force("EFIELD     FORCE",fefield,1,ry);
            }
            if (PARAM.inp.esolver_type == "tddft")
            {
                ModuleIO::print_force(GlobalV::ofs_running, ucell, "EFIELD_TDDFT     FORCE", fefield_tddft, false);
                // this->print_force("EFIELD_TDDFT     FORCE",fefield_tddft,1,ry);
            }
            if (PARAM.inp.gate_flag)
            {
                ModuleIO::print_force(GlobalV::ofs_running, ucell, "GATEFIELD     FORCE", fgate, false);
                // this->print_force("GATEFIELD     FORCE",fgate,1,ry);
            }
            if (PARAM.inp.imp_sol)
            {
                ModuleIO::print_force(GlobalV::ofs_running, ucell, "IMP_SOL     FORCE", fsol, false);
                // this->print_force("IMP_SOL     FORCE",fsol,1,ry);
            }
            if (vdw_solver != nullptr)
            {
                ModuleIO::print_force(GlobalV::ofs_running, ucell, "VDW        FORCE", force_vdw, false);
                // this->print_force("VDW        FORCE",force_vdw,1,ry);
            }
            if (PARAM.inp.dft_plus_u)
            {
                ModuleIO::print_force(GlobalV::ofs_running, ucell, "DFT+U      FORCE", force_u, false);
            }
            if (PARAM.inp.sc_mag_switch)
            {
                ModuleIO::print_force(GlobalV::ofs_running, ucell, "DeltaSpin  FORCE", force_dspin, false);
            }
#ifdef __MLALGO
            // caoyu add 2021-06-03
            if (PARAM.inp.deepks_scf)
            {
                ModuleIO::print_force(GlobalV::ofs_running, ucell, "DeePKS 	FORCE", fvnl_dalpha, true);
            }
#endif
        }

        GlobalV::ofs_running << std::setiosflags(std::ios::left);

        // this->printforce_total(ry, istestf, fcs);
        ModuleIO::print_force(GlobalV::ofs_running, ucell, "TOTAL-FORCE (eV/Angstrom)", fcs, false);
        if (istestf)
        {
            GlobalV::ofs_running << "\n FORCE INVALID TABLE." << std::endl;
            GlobalV::ofs_running << " " << std::setw(8) << "atom" << std::setw(5) << "x" << std::setw(5) << "y"
                                 << std::setw(5) << "z" << std::endl;
            for (int iat = 0; iat < ucell.nat; iat++)
            {
                GlobalV::ofs_running << " " << std::setw(8) << iat;
                for (int i = 0; i < 3; i++)
                {
                    if (std::abs(fcs(iat, i) * ModuleBase::Ry_to_eV / ModuleBase::BOHR_TO_A)
                        < Force_Stress_LCAO::force_invalid_threshold_ev)
                    {
                        fcs(iat, i) = 0.0;
                        GlobalV::ofs_running << std::setw(5) << "1";
                    }
                    else
                    {
                        GlobalV::ofs_running << std::setw(5) << "0";
                    }
                }
                GlobalV::ofs_running << std::endl;
            }
        }
    } // end of force calculation
    //---------------------------------
    // begin calculate and output stress
    //---------------------------------
    if (isstress)
    {
        for (int i = 0; i < 3; i++)
        {
            for (int j = 0; j < 3; j++)
            {
                scs(i, j) += soverlap(i, j) + stvnl_dphi(i, j) + svnl_dbeta(i, j) + svl_dphi(i, j)
                             + sigmadvl(i, j)  // derivative of local potential stress (pw)
                             + sigmaewa(i, j)  // ewald stress (pw)
                             + sigmacc(i, j)   // nonlinear core correction stress (pw)
                             + sigmaxc(i, j)   // exchange corretion stress
                             + sigmahar(i, j); // hartree stress

                // VDW stress from linpz and jiyy
                if (vdw_solver != nullptr)
                {
                    scs(i, j) += stress_vdw(i, j);
                }
                // DFT plus U stress from qux
                if (PARAM.inp.dft_plus_u)
                {
                    scs(i, j) += stress_u(i, j);
                }
                if (PARAM.inp.sc_mag_switch)
                {
                    scs(i, j) += stress_dspin(i, j);
                }
#ifdef __EXX
                // Stress contribution from exx
                if (GlobalC::exx_info.info_global.cal_exx)
                {
                    scs(i, j) += stress_exx(i, j);
                }
#endif
#ifdef __MLALGO
                if (PARAM.inp.deepks_scf)
                {
                    scs(i, j) += svnl_dalpha(i, j);
                }
#endif
            }
        }
        if (ModuleSymmetry::Symmetry::symm_flag == 1)
        {
            symm->symmetrize_mat3(scs, ucell.lat);
        } // end symmetry

        deepks.write_stress(scs, svnl_dalpha, ucell.omega, PARAM.inp);

        // print Rydberg stress or not
        bool ry = false;

        // test stress each terms if needed
        if (istests)
        {
            // test
            ModuleBase::matrix svlocal;
            svlocal.create(3, 3);
            ModuleBase::matrix stvnl;
            stvnl.create(3, 3);
            for (int i = 0; i < 3; i++)
            {
                for (int j = 0; j < 3; j++)
                {
                    svlocal(i, j) = svl_dphi(i, j) + sigmadvl(i, j);
                    stvnl(i, j) = stvnl_dphi(i, j) + svnl_dbeta(i, j);
                }
            }

            const bool screen = PARAM.inp.test_stress;

            GlobalV::ofs_running << "\n PARTS OF STRESS: " << std::endl;
            GlobalV::ofs_running << std::setiosflags(std::ios::showpos);
            GlobalV::ofs_running << std::setiosflags(std::ios::fixed) << std::setprecision(10) << std::endl;
            ModuleIO::print_stress("OVERLAP  STRESS", soverlap, screen, ry, GlobalV::ofs_running);
            ModuleIO::print_stress("T        STRESS", stvnl_dphi, screen, ry, GlobalV::ofs_running);
            ModuleIO::print_stress("VNL      STRESS", svnl_dbeta, screen, ry, GlobalV::ofs_running);
            ModuleIO::print_stress("T_VNL    STRESS", stvnl, screen, ry, GlobalV::ofs_running);
            ModuleIO::print_stress("VL_dPHI  STRESS", svl_dphi, screen, ry, GlobalV::ofs_running);
            ModuleIO::print_stress("VL_dVL   STRESS", sigmadvl, screen, ry, GlobalV::ofs_running);
            ModuleIO::print_stress("HAR      STRESS", sigmahar, screen, ry, GlobalV::ofs_running);
            ModuleIO::print_stress("EWALD    STRESS", sigmaewa, screen, ry, GlobalV::ofs_running);
            ModuleIO::print_stress("cc       STRESS", sigmacc, screen, ry, GlobalV::ofs_running);
            ModuleIO::print_stress("XC       STRESS", sigmaxc, screen, ry, GlobalV::ofs_running);
            if (vdw_solver != nullptr)
            {
                ModuleIO::print_stress("VDW      STRESS", sigmaxc, screen, ry, GlobalV::ofs_running);
            }
            if (PARAM.inp.dft_plus_u)
            {
                ModuleIO::print_stress("DFTU     STRESS", stress_u, screen, ry, GlobalV::ofs_running);
            }
            if (PARAM.inp.sc_mag_switch)
            {
                ModuleIO::print_stress("DeltaSpin  STRESS", stress_dspin, screen, ry, GlobalV::ofs_running);
            }
            ModuleIO::print_stress("TOTAL    STRESS", scs, screen, ry, GlobalV::ofs_running);
        } // end of test

        GlobalV::ofs_running << std::setiosflags(std::ios::left);

        // print total stress
        bool screen_normal = true;
        ModuleIO::print_stress("TOTAL-STRESS", scs, screen_normal, ry, GlobalV::ofs_running);

        double unit_transform = 0.0;
        unit_transform = ModuleBase::RYDBERG_SI / pow(ModuleBase::BOHR_RADIUS_SI, 3) * 1.0e-8;
        double external_stress[3] = {PARAM.inp.press1, PARAM.inp.press2, PARAM.inp.press3};

        for (int i = 0; i < 3; i++)
        {
            scs(i, i) -= external_stress[i] / unit_transform;
        }
    } // end of stress calculation

    ModuleBase::timer::tick("Force_Stress_LCAO", "getForceStress");
    return;
}

// local pseudopotential, ewald, core correction, scc terms in force
template <typename T>
void Force_Stress_LCAO<T>::calForcePwPart(UnitCell& ucell,
                                          ModuleBase::matrix& fvl_dvl,
                                          ModuleBase::matrix& fewalds,
                                          ModuleBase::matrix& fcc,
                                          ModuleBase::matrix& fscc,
                                          const double& etxc,
                                          const ModuleBase::matrix& vnew,
                                          const bool vnew_exist,
                                          const Charge* const chr,
                                          ModulePW::PW_Basis* rhopw,
                                          const pseudopot_cell_vl& locpp,
                                          const Structure_Factor& sf)
{
    ModuleBase::TITLE("Force_Stress_LCAO", "calForcePwPart");
#ifdef __CUDA
    if(PARAM.inp.device == "gpu")
    {
        Forces<double, base_device::DEVICE_GPU> f_pw(nat);
        f_pw.cal_force_loc(ucell, fvl_dvl, rhopw, locpp.vloc, chr);
        f_pw.cal_force_ew(ucell, fewalds, rhopw, &sf);
        f_pw.cal_force_cc(fcc, rhopw, chr, locpp.numeric, ucell);
        f_pw.cal_force_scc(fscc, rhopw, vnew, vnew_exist, locpp.numeric, ucell);
	}
	else
#endif
    {
        Forces<double, base_device::DEVICE_CPU> f_pw(nat);
        f_pw.cal_force_loc(ucell, fvl_dvl, rhopw, locpp.vloc, chr);
        f_pw.cal_force_ew(ucell, fewalds, rhopw, &sf);
        f_pw.cal_force_cc(fcc, rhopw, chr, locpp.numeric, ucell);
        f_pw.cal_force_scc(fscc, rhopw, vnew, vnew_exist, locpp.numeric, ucell);
    }

    return;
}

// overlap, kinetic, nonlocal pseudopotential, Local potential terms in force and stress
template <>
void Force_Stress_LCAO<double>::integral_part(const bool isGammaOnly,
		const bool isforce,
		const bool isstress,
		const UnitCell& ucell,
		const Grid_Driver& gd,
		ForceStressArrays& fsr, // mohan add 2024-06-15
		const elecstate::ElecState* pelec,
		const elecstate::DensityMatrix<double, double>* dm, // mohan add 2025-11-04
		const psi::Psi<double>* psi,
		ModuleBase::matrix& foverlap,
		ModuleBase::matrix& ftvnl_dphi,
		ModuleBase::matrix& fvnl_dbeta,
		ModuleBase::matrix& fvl_dphi,
		ModuleBase::matrix& soverlap,
		ModuleBase::matrix& stvnl_dphi,
		ModuleBase::matrix& svnl_dbeta,
		ModuleBase::matrix& svl_dphi,
		ModuleBase::matrix& fvnl_dalpha,
		ModuleBase::matrix& svnl_dalpha,
		Setup_DeePKS<double>& deepks,
		const TwoCenterBundle& two_center_bundle,
		const LCAO_Orbitals& orb,
		const Parallel_Orbitals& pv,
		const K_Vectors& kv)
{

    flk.ftable(isforce, isstress, fsr, ucell, gd, psi, pelec, dm,
               foverlap, ftvnl_dphi, fvnl_dbeta, fvl_dphi,
               soverlap, stvnl_dphi, svnl_dbeta, svl_dphi,
               fvnl_dalpha, svnl_dalpha, deepks, two_center_bundle, orb, pv);
    return;
}

template <>
void Force_Stress_LCAO<std::complex<double>>::integral_part(const bool isGammaOnly,
		const bool isforce,
		const bool isstress,
		const UnitCell& ucell,
		const Grid_Driver& gd,
		ForceStressArrays& fsr, // mohan add 2024-06-15
		const elecstate::ElecState* pelec,
		const elecstate::DensityMatrix<std::complex<double>, double>* dm, // mohan add 2025-11-04
		const psi::Psi<std::complex<double>>* psi,
		ModuleBase::matrix& foverlap,
		ModuleBase::matrix& ftvnl_dphi,
		ModuleBase::matrix& fvnl_dbeta,
		ModuleBase::matrix& fvl_dphi,
		ModuleBase::matrix& soverlap,
		ModuleBase::matrix& stvnl_dphi,
		ModuleBase::matrix& svnl_dbeta,
		ModuleBase::matrix& svl_dphi,
		ModuleBase::matrix& fvnl_dalpha,
		ModuleBase::matrix& svnl_dalpha,
		Setup_DeePKS<std::complex<double>>& deepks,
		const TwoCenterBundle& two_center_bundle,
		const LCAO_Orbitals& orb,
		const Parallel_Orbitals& pv,
		const K_Vectors& kv)
{
    flk.ftable(isforce, isstress, fsr, ucell, gd, psi, pelec, dm,
               foverlap, ftvnl_dphi, fvnl_dbeta, fvl_dphi,
               soverlap, stvnl_dphi, svnl_dbeta, svl_dphi,
               fvnl_dalpha, svnl_dalpha, deepks,
               two_center_bundle, orb, pv, &kv, this->RA);
    return;
}

// vlocal, hartree, ewald, core correction, exchange-correlation terms in stress
template <typename T>
void Force_Stress_LCAO<T>::calStressPwPart(UnitCell& ucell,
                                           ModuleBase::matrix& sigmadvl,
                                           ModuleBase::matrix& sigmahar,
                                           ModuleBase::matrix& sigmaewa,
                                           ModuleBase::matrix& sigmacc,
                                           ModuleBase::matrix& sigmaxc,
                                           const double& etxc,
                                           const Charge* const chr,
                                           ModulePW::PW_Basis* rhopw,
                                           const pseudopot_cell_vl& locpp,
                                           const Structure_Factor& sf)
{
    ModuleBase::TITLE("Force_Stress_LCAO", "calStressPwPart");

    // local pseudopotential stress:
    sc_pw.stress_loc(ucell, sigmadvl, rhopw, locpp.vloc, &sf, 0, chr);

    // hartree term
    sc_pw.stress_har(ucell, sigmahar, rhopw, 0, chr);

    // ewald stress: use plane wave only.
    sc_pw.stress_ewa(ucell, sigmaewa, rhopw, 0); // remain problem

    // stress due to core correlation.
    sc_pw.stress_cc(sigmacc, rhopw, ucell, &sf, 0, locpp.numeric, chr);

    // stress due to self-consistent charge.
    for (int i = 0; i < 3; i++)
    {
        sigmaxc(i, i) = -etxc / ucell.omega;
    }
    // Exchange-correlation for PBE
    sc_pw.stress_gga(ucell, sigmaxc, rhopw, chr);

    return;
}

#include "source_base/mathzone.h"
// do symmetry for total force
template <typename T>
void Force_Stress_LCAO<T>::forceSymmetry(const UnitCell& ucell, ModuleBase::matrix& fcs, ModuleSymmetry::Symmetry* symm)
{
    double d1, d2, d3;
    for (int iat = 0; iat < ucell.nat; iat++)
    {
        ModuleBase::Mathzone::Cartesian_to_Direct(fcs(iat, 0), fcs(iat, 1), fcs(iat, 2),
          ucell.a1.x, ucell.a1.y, ucell.a1.z, ucell.a2.x, ucell.a2.y, ucell.a2.z,
          ucell.a3.x, ucell.a3.y, ucell.a3.z, d1, d2, d3);

        fcs(iat, 0) = d1;
        fcs(iat, 1) = d2;
        fcs(iat, 2) = d3;
    }
    symm->symmetrize_vec3_nat(fcs.c);
    for (int iat = 0; iat < ucell.nat; iat++)
    {
        ModuleBase::Mathzone::Direct_to_Cartesian(fcs(iat, 0), fcs(iat, 1), fcs(iat, 2),
          ucell.a1.x, ucell.a1.y, ucell.a1.z, ucell.a2.x, ucell.a2.y, ucell.a2.z,
          ucell.a3.x, ucell.a3.y, ucell.a3.z, d1, d2, d3);

        fcs(iat, 0) = d1;
        fcs(iat, 1) = d2;
        fcs(iat, 2) = d3;
    }
    return;
}

template class Force_Stress_LCAO<double>;
template class Force_Stress_LCAO<std::complex<double>>;
